import gzip
import os
import shutil
from urllib import request

import numpy as np
import pandas as pd

# Source of the MNIST training set.
# The official archives are published at
#   http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
#   http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
# Downloading from that host is slow, so the archives were uploaded to a
# Gitee repository and are fetched from there instead.
# `git clone` refuses to write into an existing non-empty directory, so the
# clone is skipped when a checkout is already present (e.g. on a re-run).
if not os.path.isdir("./anolis-sys-tests-res"):
    os.system("git clone https://gitee.com/yunmeng365524/anolis-sys-tests-res.git")

# Local paths where the MNIST archives and the generated CSV are stored.
path_train_images = "./data/train-images.gz"
path_train_labels = "./data/train-labels.gz"
path_train = "./data/train.csv"

# The output directory must exist before anything is copied into it.
os.makedirs("./data", exist_ok=True)

# Fetch the MNIST archives from the cloned mirror (only those not already
# present locally).
for _mirrored, _local in (
    ("./anolis-sys-tests-res/res/train-images-idx3-ubyte.gz", path_train_images),
    ("./anolis-sys-tests-res/res/train-labels-idx1-ubyte.gz", path_train_labels),
):
    if os.path.exists(_local):
        continue
    # Fail here with a clear message instead of letting a silent clone/copy
    # failure surface later as a confusing error from gzip.open().
    if not os.path.isfile(_mirrored):
        raise FileNotFoundError(
            f"{_mirrored} is missing; the git clone of the mirror repository "
            f"probably failed, so {_local} cannot be created"
        )
    shutil.copyfile(_mirrored, _local)

# Decompress the archives next to themselves; `[:-3]` strips the ".gz"
# suffix. The data is streamed so the whole file is never held in memory.
for _archive in (path_train_images, path_train_labels):
    with gzip.open(_archive, "rb") as f_in, open(_archive[:-3], "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)

# Load the decompressed MNIST files.
with open(path_train_images[:-3], "rb") as f:
    # Skip the 16-byte file header.
    f.read(16)
    # Remaining bytes are pixel data: one row of 28*28 uint8 values per image.
    train_images = np.frombuffer(f.read(), dtype=np.uint8).reshape(-1, 28 * 28)

with open(path_train_labels[:-3], "rb") as f:
    # Skip the 8-byte file header.
    f.read(8)
    # Remaining bytes are the labels, one uint8 per image.
    train_labels = np.frombuffer(f.read(), dtype=np.uint8)

# Write everything to a single CSV: a "label" column followed by the 784
# pixel columns named 0..783.
pd.DataFrame(
    np.hstack([train_labels.reshape(-1, 1), train_images]),
    columns=["label"] + list(range(784)),
).to_csv(path_train, index=False)
